from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy, time
from Network.network_utils import reduce_function, get_acti, pytorch_model, cuda_string
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.Pair.pair import PairNetwork
from Network.General.Factor.Attention.attn_utils import evaluate_key_query, mask_query, init_key_query_args, init_final_args
from Network.General.Factor.Attention.base_attention import BaseMaskedAttentionNetwork


class MultiHeadAttentionMVParallelLayer(Network):
    """Multi-head attention of a single key over a set of queries.

    The key is embedded with an ``MLPNetwork``, the queries with a
    ``ConvNetwork``, and the values with either a ``PairNetwork`` (when
    ``append_keys`` is set) or a ``ConvNetwork``. Attention weights are produced
    by ``evaluate_key_query``; heads are merged with ``reduce_function`` using
    ``args.mask_attn.merge_function``. When ``forward`` is called with
    ``query_final=False`` the merged result is passed through ``final_network``.
    """
    def __init__(self, args):
        super().__init__(args)
        # feature flags read from the factor-network config
        self.append_keys = args.factor_net.append_keys  # selects PairNetwork vs ConvNetwork for values
        self.append_mask = args.factor_net.append_mask  # not read elsewhere in this class
        self.append_broadcast_mask = args.factor_net.append_broadcast_mask  # used as a feature width in forward(); falsy disables it
        
        self.softmax =  nn.Softmax(-1)
        self.model_dim = args.mask_attn.model_dim # the dimension of the keys and queries after network (per head)
        # NOTE(review): the invariant below is documented but not enforced (assert is commented out)
        # assert(args.embed_dim == args.mask_attn.model_dim * args.mask_attn.num_heads or args.mask_attn.merge_function != "cat")
        self.key_dim = args.embed_dim # the dimension of the key inputs, must equal model_dim * num_heads (not checked here)
        self.query_dim = args.embed_dim
        self.num_heads = args.mask_attn.num_heads
        self.merge_function = args.mask_attn.merge_function  # passed to reduce_function to merge heads; "cat" gets extra transposes
        concatenated_values = self.merge_function == "cat"  # only referenced by the commented-out assertion below
        # assert args.embed_dim % self.num_heads == 0 or (not concatenated_values), f"head and key not divisible, key: {args.embed_dim}, head: {self.num_heads}"
        self.head_dim = int(args.embed_dim // self.num_heads) # should be key_dim / num_heads, integer divisible; not read elsewhere in this class
        self.no_hidden = args.mask_attn.no_hidden  # not read elsewhere in this class
        self.mask_mode = args.mask_attn.mask_mode  # "query": mask queries before the value network; "attn": may enable bernoulli weights

        key_query_args = init_key_query_args(args, use_broadcast_mask=False)
        # process one key at a time
        self.key_network = MLPNetwork(key_query_args)
        self.query_network = ConvNetwork(key_query_args)

        value_args = init_key_query_args(args, use_broadcast_mask=True) # only differs if append_keys
        if self.append_keys: self.value_network = PairNetwork(value_args)
        else: self.value_network = ConvNetwork(value_args)

        final_args = init_final_args(args)
        self.final_network = MLPNetwork(final_args)  # only applied when forward() is called with query_final=False

        # sub-networks collected in one list
        self.model = [self.key_network, self.query_network, self.value_network, self.final_network]

    def mask_softmax(self, softmax, mask):
        """Scale ``softmax`` entries by ``mask`` along the last dimension.

        Transposing, multiplying by ``mask.unsqueeze(-1)`` and transposing back
        multiplies entry ``[..., a, b]`` by ``mask[..., b]``. This assumes the
        last dim of ``mask`` broadcasts against the last dim of ``softmax``.
        Not called from within this class.
        """
        return (softmax.transpose(-2,-1) * mask.unsqueeze(-1)).transpose(-2,-1)

    def forward(self, key, queries, mask, query_final=False, valid=None):
        """Attend a single key over all queries.

        Args:
            key: single key input. After ``key_network`` it is reshaped to
                ``(batch, num_heads, model_dim, 1)``; ``key.shape[0]`` is used as
                the batch size.
            queries: query inputs. They are transposed on the last two dims
                before and after ``query_network`` (a ConvNetwork), and dim 1 of
                the embedded queries is used as the number of queries.
                Presumably ``(batch, num_queries, query_dim)`` -- TODO confirm.
            mask: per-query mask. Presumably ``(batch, num_queries)`` -- TODO
                confirm (the broadcast-mask branch treats it as 2-D).
            query_final: if True, keep one weighted value per query and skip
                ``final_network``. Otherwise take the attention-weighted sum of
                the values over queries and apply ``final_network``.
            valid: forwarded to ``mask_query`` when ``mask_mode == "query"`` and
                ``append_keys`` is False.

        Returns:
            ``(values, weights, embed)``. ``embed`` is the raw ``key`` input when
            ``query_final`` is True, and the ``final_network`` output (the same
            tensor as ``values``) otherwise.
        """
        # applies the mask at the queries and the values
        # alternatively, apply the mask after the softmax
        batch_size = key.shape[0]
        embed = key  # raw (un-embedded) key; returned as-is when query_final is True
        # uncomment below if queries = queries * mask.unsqueeze(-1) is commented, or we want safer operations that mask out the whole value
        # NOTE(review): no commented-out code follows; the comment above appears stale.

        if self.append_keys:
            # print(key.shape, queries.shape, mask.shape, self.value_network)
            # value_network is a pair_net in this case, and mask (and append_mask) is handled there
            values = self.value_network(key.unsqueeze(1), queries, mask.unsqueeze(1), list())[0] # pairnet uses different inputs, performs masking
        else:
            val_queries = mask_query(queries, mask, valid, single_key=True) if self.mask_mode == "query" else queries
            if self.append_broadcast_mask:
                # append (1 - mask) repeated append_broadcast_mask times as extra features per query
                broadcast_mask = 1-mask.unsqueeze(-1).broadcast_to(mask.shape[0], mask.shape[-1], self.append_broadcast_mask)
                val_queries = torch.cat([val_queries, broadcast_mask], dim=-1)
            # the conv value network sees the last two dims swapped; swap dims 1 and 2 back afterwards
            values = self.value_network(val_queries.transpose(-2,-1)).transpose(1,2)
        
        # perform key query embeddings for attention
        key = self.key_network(key)
        queries = self.query_network(queries.transpose(-2,-1)).transpose(-2,-1)

        # print(key.shape, values.shape, queries.shape, self.num_heads, self.model_dim)
        # split the feature dim into heads: values/queries -> batch x heads x queries x model_dim
        values = values.reshape(batch_size, queries.shape[1], self.num_heads, self.model_dim).transpose(1,2)
        key, queries = key.reshape(batch_size, self.num_heads, self.model_dim, 1), queries.reshape(batch_size, -1, self.num_heads, self.model_dim).transpose(1,2)
        # NOTE(review): self.ap is not assigned in this class; presumably provided by the Network base -- confirm.
        # It is only evaluated when query_final is True (short-circuit `and`), and the attribute name
        # "bernoilli_weights" (sic) must match whatever that object defines.
        if self.mask_mode == "attn": weights = evaluate_key_query(self.softmax, key, queries, mask, single_key=True, use_bernoulli=query_final and self.ap.bernoilli_weights)
        else: weights = evaluate_key_query(self.softmax, key, queries, mask, single_key=True)
        if query_final: # preserve the queries
            # batch x heads x queries x 1 * batch x heads x queries x model_dim = batch x heads x queries x model_dim
            values = (weights.unsqueeze(-1) * values )
            if self.merge_function == 'cat': values = values.transpose(3,2) #
            # if not cat: batch x keys x queries x model_dim (merges the heads)
            # if cat: batch x heads x model dim x queries -> batch x heads * model dim x queries (flip queries back)
            values = reduce_function(self.merge_function, values, dim=1) # reduces the heads
            if self.merge_function == 'cat': values = values.transpose(1,2)
            # print("merged", values.shape)
            # values = reduce_function(self.merge_function, values.transpose(1,2), dim=2).transpose(1,2) # batch x keys x queries x model_dim merges with the same function as all the others
            # values of shape batch x queries x final dimension
            # weights of shape batch x queries
        else:
            # attention-weighted sum over queries:
            # batch x heads x 1 x queries * batch x heads x queries x model_dim = batch x heads x 1 x model_dim
            values = torch.matmul(weights.unsqueeze(-2), values)[:,:,0,:]
            values = reduce_function(self.merge_function, values, dim=1)  # merge the heads
            values = self.final_network(values)
            embed = values
            # values of shape batch x final dimension
            # weights of shape batch x queries
        # print(values.shape, weights.shape, embed.shape)
        return values, weights, embed

class MaskedAttentionNetwork(BaseMaskedAttentionNetwork):
    """Masked attention network that runs the attention layer once per key.

    Construction is delegated to ``BaseMaskedAttentionNetwork`` with
    ``MultiHeadAttentionMVParallelLayer`` as the attention layer class.
    """

    def __init__(self, args):
        # TODO: initialization is exactly the same as parallel attention
        super().__init__(args, MultiHeadAttentionMVParallelLayer)
        self.fp = args.factor

    def compute_attention(self, key, query, mask):
        """Apply ``self.multi_head_attention`` to each key separately and stack the results.

        ``self.multi_head_attention`` is presumably built by
        ``BaseMaskedAttentionNetwork`` from the layer class passed in ``__init__``
        -- TODO confirm.

        Args:
            key: tensor indexed as ``key[:, i]``; dim 1 enumerates the keys.
            query: passed unchanged to every per-key attention call.
            mask: tensor indexed as ``mask[:, i, :]`` for each key ``i``; its
                dim 1 must be at least ``key.shape[1]`` long.

        Returns:
            ``(values, attns, embed)``. ``values`` and ``attns`` are the per-key
            outputs stacked along dim 1. ``embed`` is the embed returned for the
            LAST key only.
        """
        values, attns = list(), list()  # per-key outputs, stacked below
        # iterate by index over key dim 1 so a too-short mask still raises
        for i in range(key.shape[1]):
            value, attn, embed = self.multi_head_attention(key[:, i], query, mask[:, i, :])
            values.append(value)
            attns.append(attn)
        # query_final is left at its default (False), so each per-key value is the layer's
        # final-network output; presumably values: [batch, num_keys, final_dim] and
        # attns: [batch, num_keys, num_queries] -- TODO confirm against the layer.
        return torch.stack(values, dim=1), torch.stack(attns, dim=1), embed